import ekf_testv6 as ekf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as ss
import summary_of_sims as sum_sims
from statsmodels.graphics import tsaplots
import matplotlib.transforms as mtransforms

# Number of series rows held in `values`.
matsh = 10
# Grid of annual-change values (°C/yr) on which the KDEs are evaluated.
eval_points = np.arange(-.25, .15, .001)
# Row layout of `values` (and index meaning of labels/shlabels/colorlabels):
#   0 running 30-year mean, 1 OCN, 2 Butterworth, 3 change point (rows 1-3 are
#   not filled in this script), 4 EBM-KF-uf, 5 EBM-KF-ta (trailing-averaged
#   forcing), 6 EBM-KF-ca (centered-averaged forcing), 7 LENS2 ensemble mean,
#   8 blind model, 9 observed temperatures (HadCRUT5).
# `stdevs` / `moderrs` stop two rows short (matsh-2) because the blind model and
# the observations carry no uncertainty estimates.

# Backslashes are doubled so the mathtext reaches matplotlib unchanged without
# relying on unrecognized string escapes (SyntaxWarning since Python 3.12).
labels = [
    "Running Average $\\overline{_{30}Y_n}$",
    "OCN Chunks",
    "Butterworth Smoothed",
    "Change Point Lines",
    "EBM-KF-uf $\\hat {T }_n$",
    "EBM-KF-ta",
    "EBM-KF-ca",
    "LENS2 $\\overline{(Y_t)_j}$",
    "Blind Model $\\~{T}_{t}$",
]
shlabels = [
    "RunAvg \n $\\overline{_{30}Y_t}$",
    "OCN \n Chunks",
    "ButW \n Smooth",
    "$\\Delta$Pt \n Lines",
    "EBM-KF-uf \n $\\hat {T }_t$",
    "EBM-KF-ta",
    "RTS \n $ \\hat \\hat {T }_t$",
    "LENS2 \n $\\overline{(Y_t)_j}$",
    "Blind \n $\\~{T}_{t}$",
    "HadCRUT5 \n ${Y}_{t}$",
]
colorlabels = [
    "yellow", "limegreen", "springgreen", ekf.pcolor, ekf.colorekf,
    "darkolivegreen", "magenta", sum_sims.cesmcolor, "darkgoldenrod", ekf.colorgrey,
]
oneswt = np.ones(ekf.n_iters)
# NaN marks "no value for this series/year"; filled row by row below.
values = np.full((matsh, ekf.n_iters), np.nan)
stdevs = np.full((matsh - 2, ekf.n_iters), np.nan)
moderrs = np.full((matsh - 2, ekf.n_iters), np.nan)
temps = ekf.temps

# Fill the rows of values/stdevs/moderrs that come straight from precomputed data.

# Sanity check (printed, not enforced): one observation per filter step.
print(len(ekf.temps) == ekf.n_iters)

# Row 0: 30-year running mean, taken precomputed from ekf.
values[0, :] = ekf.moving_aves
stdevs[0, :] = ekf.std_aves
moderrs[0, :] = ekf.std_aves / np.sqrt(ekf.N)

# Row 4: EBM-KF (unsmoothed forcing) state estimate and its two spreads.
# Slice assignment already copies, so no explicit np.copy is needed.
values[4, :] = ekf.xh1s
stdevs[4, :] = ekf.stdS
moderrs[4, :] = ekf.stdP

# Row 7: LENS2 simulations (first 174 years) ...
lenss1 = 174
values[7, :lenss1] = sum_sims.twTRmean[:lenss1]
stdevs[7, :lenss1] = sum_sims.twTRstd[:lenss1]
# ... with a 165-year spread taken from row 20 of the CMIP6 summary table
# (NOTE(review): meaning of row 20 is defined by whatever wrote the CSV).
lenss = 165
# Context manager so the file handle is closed once parsed.
with open("summaryCMIP6.csv", "rb") as cmip6_file:
    CMIP6 = np.genfromtxt(cmip6_file, dtype=float, delimiter=',')
moderrs[7, :lenss] = CMIP6[20, :]

# Row 8: blind model (first state component).
values[8, :] = ekf.xblind[:, 0]

# Row 9: observed temperatures.
values[9, :] = temps

## TRAILING average
# Re-run the EBM-KF with volcanic forcing smoothed over a 30-year window in which
# only year i and the fN years before it come from the record; the remaining
# cN-1 "future" slots are filled with ekf.involcavg (the future is assumed to be
# the average). Smoothing is done on w = 1/(opt_depth + 9.7279) and mapped back.
wt_opt_depths = 1 / (ekf.opt_depth + 9.7279)
N = 30
cN = int(np.ceil(N / 2))
fN = int(np.floor(N / 2))
# (fN + 1) observed values + (cN - 1) mean-filled slots = N values per window.
# Years before fN lack a full history and keep the mean value.
nwt_opt_depths = np.full(len(ekf.opt_depth), ekf.involcavg, dtype=float)
# The loop runs through the final year as well: its trailing window
# (years i-fN .. i) is fully available, so it must not be left at the mean.
for i in range(fN, len(nwt_opt_depths)):
    firsta = i - fN
    nwt_opt_depths[i] = (np.sum(wt_opt_depths[firsta:i + 1]) + ekf.involcavg * (cN - 1)) / N
nopt_depths = 1 / nwt_opt_depths - 9.7279
# NOTE(review): rebinds module state in ekf; presumably ekf_run reads
# ekf.opt_depth as its volcanic forcing — confirm. It is not restored afterwards.
ekf.opt_depth = nopt_depths
(trailfilter_xhat, P2s, S2s, trailfilter_xhatm) = ekf.ekf_run(ekf.observ, ekf.n_iters, retPs=2)
xh2s = np.copy(trailfilter_xhat[:, 0])
stdP2s = np.sqrt(np.abs(P2s))[:, 0, 0]
# Row 5: EBM-KF-ta state and spreads.
values[5, :] = xh2s
stdevs[5, :] = stdP2s
moderrs[5, :] = np.sqrt(np.abs(S2s))[:, 0, 0]



# CENTERED average
# Same 1/(opt_depth + 9.7279) transform as the trailing case, but each year whose
# full window fits averages record years i-fN .. i+cN-1; all other years keep
# ekf.involcavg. The EKF is then run without the last 15 years.
N = 30
cN = int(np.ceil(N / 2))
fN = int(np.floor(N / 2))
nwt2_opt_depths = np.empty(len(ekf.opt_depth))
nwt2_opt_depths.fill(ekf.involcavg)
last_full = len(nwt2_opt_depths) - cN  # last index whose window ends inside the record
for i in range(fN, last_full + 1):
    firsta = i - fN
    lasta = i + cN
    nwt2_opt_depths[i] = np.sum(wt_opt_depths[firsta:lasta]) / N
nopt2_depths = 1 / nwt2_opt_depths - 9.7279
# NOTE(review): rebinds ekf.opt_depth again, presumably as the forcing read by ekf_run.
ekf.opt_depth = nopt2_depths
centfilter_xhat, P2sa, S2sa, centfilter_xhatm = ekf.ekf_run(ekf.observ, ekf.n_iters - 15, retPs=2)
# Row 6: EBM-KF-ca; slice assignment copies the first state component.
values[6, :-15] = centfilter_xhat[:, 0]





# Figure comparing annual temperature changes (panel b) and their KDEs (panel c).
# figh: 1x4 grid -> col 0 axla (panel a), col 2 axp (panel b), col 3 ax_histy
# (panel c, shares y with axp); col 1 is left empty. The temperature time series
# axes `axl` lives in a separate figure fig2h.
# NOTE(review): nothing in this script draws into axla — confirm whether axl was
# meant to be axla.
figh = plt.figure(figsize=(12,4))
fig2h,axl=plt.subplots(figsize=(6,4))
#figh.subplots_adjust(wspace=0.4)
gs = figh.add_gridspec(1, 4,  width_ratios=(4,.7,4,1),
                      left=0.1, right=0.9, bottom=0.1, top=0.9,
                      wspace=0.2, hspace=0.05)
axla=figh.add_subplot(gs[0, 0])
axp=figh.add_subplot(gs[0, 2])
axp.set_ylim([-0.225,0.1])
ax_histy = figh.add_subplot(gs[0, 3], sharey=axp)
# Histogram bins; only referenced by the commented-out histogram calls below.
binwidth = 0.005
xymax = 0.21
lim = (int(xymax/binwidth) + 1) * binwidth
bins = np.arange(-lim, lim + binwidth, binwidth)

# Observation minus EBM-KF-uf state (all but the last year); not used further here.
diffptsEKF=-values[4,:-1]+temps[0:-1]
# Year-to-year change of the EBM-KF-uf state.
derivptsEKF=values[4,1:]-values[4,:-1]

# Same two quantities for the running mean, skipping 15 leading / 16 trailing
# years (presumably where the centered 30-year mean is undefined — TODO confirm).
diffpts30m=-values[0,15:-16]+temps[15:-16]
derivpts30m= values[0,16:-15]-values[0,15:-16]
axp.plot(ekf.dates[1:],derivptsEKF,'.',color=colorlabels[4],label="EBM-KF-uf",zorder=2, markersize=6)
#ax_histy.hist(derivptsEKF[1:], bins=bins, histtype=u'step',orientation='horizontal',color=colorlabels[4],zorder=2)
# KDE of the EBM-KF-uf changes; the first change is left out of the density.
ax_histy.plot(ss.gaussian_kde(derivptsEKF[1:]).pdf(eval_points),eval_points,color=colorlabels[4])
axp.plot(ekf.dates[16:-15],derivpts30m,'.',color='yellow', markeredgecolor=colorlabels[1],label="30-year running mean",zorder=3,alpha=1, markersize=6,markeredgewidth=1)
#ax_histy.hist(derivpts30m, bins=bins,orientation='horizontal',color=colorlabels[0],zorder=1,alpha=0.5)
deriv30hist = ss.gaussian_kde(derivpts30m).pdf(eval_points)
# Drawn twice: solid limegreen underneath, thin dashed yellow on top.
ax_histy.plot(deriv30hist,eval_points,color=colorlabels[1])
ax_histy.plot(deriv30hist,eval_points,'--',color='yellow',linewidth=0.8)
##axp.set_xlabel("Innovation from Mean/State to Measurement")
axp.set_ylabel("Change in Temperature State (°C/yr)")
axp.set_title("Comparison of Annual Temperature Changes")
##
##legend1=plt.legend(loc="lower right")
##
# Eruption years per volcano; the annotation loop later in this script marks
# only the first year of each list (v[0]).
noteYears = [[1991,1992,1993,1994,1995,1996,1997,1998],[1963,1964,1965],[1902,1903,1904],[1883,1884,1885],
		[1856,1857,1858]]
noteVolcs = ["Mt. Pinatubo, \nPhilippines", "Mt. Agung, \nIndonesia","Santa Maria, \nGuatemala", "Krakatoa, \nIndonesia", \
              "Mt. Awu, \nIndonesia"] # "Komaga-take, Japan" "Shiveluch, Russia"
# Marker symbols; not referenced elsewhere in this script.
symbols = ['^', '>', 's','P','*','X']




# Diagnostics: argmax index is relative to the start of the 15:-16 slice.
print("MAX INDEX")
print(np.argmax(diffpts30m))
# The 20 most negative annual EBM-KF-uf changes, paired with their dates.
sortedderivs = sorted(zip(ekf.dates[1:],derivptsEKF), key=lambda x:x[1])
print(sortedderivs[0:20])
r2 = ekf.r2_score(values[0,15:-16], values[4,15:-16])
print('r2 score for running mean is to EBM-KF', r2)

axlabels = ['a)', 'b)', 'c)']

axs = (axla, axp, ax_histy)
# Put a panel letter just above each axes' top-left corner. Offsets are in points
# (1/72 inch): 40 pt to the left for panels a/b, 20 pt to the right for panel c.
for i, (ax, label) in enumerate(zip(axs, axlabels)):
    dx = 20 / 72 if i == 2 else -40 / 72
    trans = mtransforms.ScaledTranslation(dx, 1 / 72, figh.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom')



# --- EBM-KF-ta (trailing-averaged forcing) overlays ---
# Time series on axl as an offset from ekf.pindavg; zorder=-1 keeps it underneath.
axl.plot(ekf.dates, xh2s - ekf.pindavg, '-', label='EBM-KF-ta', color='darkolivegreen', zorder=-1)

# Year-to-year state changes, skipping the first one (xh2s[1] - xh2s[0]).
deriv_ta = xh2s[2:] - xh2s[1:-1]
axp.plot(ekf.dates[2:], deriv_ta, '.', color='darkolivegreen', label="EBM-KF-ta")
ax_histy.plot(ss.gaussian_kde(deriv_ta).pdf(eval_points), eval_points, color="darkolivegreen", zorder=5)
axp.set_xlabel("Year")
# The side panel shows gaussian_kde(...).pdf values, i.e. probability densities,
# not histogram counts, so "# of Years" would mislabel the axis.
ax_histy.set_xlabel("Probability Density")
r2 = ekf.r2_score(values[0, 15:-16], xh2s[15:-16])
print('r2 score for running mean is to EBM-KF-ta', r2)


# Filtered temperature series on axl, plotted as offsets from ekf.pindavg.
toplot = [4, 6]
for i in toplot:
    axl.plot(ekf.dates, values[i, :] - ekf.pindavg, '-', label=labels[i], color=colorlabels[i])
# Running mean drawn twice: solid limegreen underneath, thin dashed yellow on top.
axl.plot(ekf.dates, values[0, :] - ekf.pindavg, '-', label=labels[0], color=colorlabels[1])
axl.plot(ekf.dates, values[0, :] - ekf.pindavg, '--', label=labels[0], color="yellow", linewidth=0.8)

ekf.plot_boilerplate(axl)
axl.set_yticks(np.arange(-0.4, 1.8, 0.2))
axl.set_ylabel('$\\Delta$ Temperature (°C)\nAbs(K)=$\\Delta(T)$ + 286.7K')
axl.set_ylim(286.5 - ekf.pindavg, 287.7 - ekf.pindavg)

# Proxy artists so the running-mean legend entry shows both overlaid line styles.
dotted_line1 = plt.Line2D([], [], linestyle="--", color="yellow", linewidth=0.8)
dotted_line2 = plt.Line2D([], [], linestyle="-", color=colorlabels[1])

# Unpack into `legend_labels`, not `labels`: rebinding `labels` here would
# clobber the module-level list of series names for any later use.
handles0, legend_labels = axl.get_legend_handles_labels()
# Handle order follows plotting order on axl (EBM-KF-ta, EBM-KF-uf, EBM-KF-ca,
# running mean solid, running mean dashed) — assumes ekf.plot_boilerplate adds no
# labelled artists ahead of these; confirm.
handles0[3] = (dotted_line2, dotted_line1)
order = [0, 3, 2, 1]
axl.legend([handles0[i] for i in order], [legend_labels[i] for i in order], loc='upper left')

axp.set_xticks(np.arange(1850, 2025 + 1, 25))
axp.set_xlim(1850, 2025)
# Per-volcano annotation heights: shiftup is added to -0.2 on axp,
# shift0 is the absolute y position on axl.
shiftup = [0, 0, 0, 0.03, 0]
shift0 = [0, 0, 0.4, 0.2, 0.4]
vline_y = [-0.5, 1.7]
for vi, v in enumerate(noteYears):
    eruption_year = v[0]
    volc_name = noteVolcs[vi]
    # Dotted grey marker and name on the temperature panel ...
    axl.plot([eruption_year, eruption_year], vline_y, ':', color='0.7', zorder=0)
    axl.annotate(volc_name, (eruption_year, shift0[vi]), fontsize=8)
    # ... and on the annual-change panel.
    axp.plot([eruption_year, eruption_year], vline_y, ':', color='0.7', zorder=0)
    axp.annotate(volc_name, (eruption_year, -0.2 + shiftup[vi]), fontsize=8)


# Right-hand axis in °C: 286.5-287.7 K converted via -273.15, the same absolute
# range that axl's limits express as offsets from ekf.pindavg.
axb = axl.twinx()
axb.set_yticks(np.arange(12.4, 18.4, 0.2))
axb.set_ylim(286.5 - 273.15, 287.7 - 273.15)
axb.set_ylabel('Temperature (°C)')





# Optical depth recovered from the unsmoothed, trailing- and centered-smoothed
# transformed series, then display all figures.
plt.figure()
for smoothed, line_color in ((wt_opt_depths, ekf.colorekf),
                             (nwt_opt_depths, 'darkolivegreen'),
                             (nwt2_opt_depths, 'magenta')):
    plt.plot(ekf.dates, 1 / smoothed - 9.7279, color=line_color)

plt.show()





